import torch
import math
import numpy as np


class NerfEmbedding(torch.nn.Module):
    """Random Fourier feature embedding ``sqrt(2) * cos(x @ k.T + b)``.

    The frequency matrix ``k`` and the phase vector ``b`` are drawn once at
    construction time and stored as (non-trainable) buffers.
    """

    def __init__(self, sample_function_name, num_features, dimensions, **sample_args):
        '''
        Draw the random frequencies and phases used by the embedding.

        :param sample_function_name: name of one of the ``sample_from_*`` static
        methods of this class; it is used to draw the frequency matrix. A name
        that does not resolve to a callable falls back to ``sample_from_normal``
        :param num_features: the number of random features, i.e. the number of
        rows of ``k`` and the size of the last axis of the embedding
        :param dimensions: sequence with one entry per input coordinate; its
        length fixes the number of columns of ``k`` and some samplers use its
        values to rescale the frequencies
        :param sample_args: extra keyword arguments forwarded to the sampler
        '''
        super().__init__()

        # The fallback has to be the sampler itself (not its name), otherwise
        # an unknown name ends in "'str' object is not callable".
        sample_function = getattr(self, sample_function_name, None)
        if not callable(sample_function):
            sample_function = self.sample_from_normal

        self.register_buffer(
            'k', 2 * math.pi * sample_function(num_features, dimensions, **sample_args)
        )
        # One random phase per feature, uniform in [0, 2*pi).
        self.register_buffer('b', 2 * math.pi * torch.rand(num_features))
        self.output_features = num_features

    @staticmethod
    def sample_from_ids(num_samples, dimensions, lambdas):
        '''
        Sample integer multi-indices of ``lambdas`` with probability given by
        the corresponding entry.

        :param num_samples: number of multi-indices to draw
        :param dimensions: expected shape of ``lambdas`` (checked with an assert)
        :param lambdas: array of probabilities; its flattened values are passed
        as ``p`` to ``np.random.choice``
        :return: float tensor of shape (num_samples, number of axes of lambdas)
        '''
        assert list(dimensions) == list(lambdas.shape)

        num_cells = np.prod(np.array(list(lambdas.shape)))
        flat_indices = np.arange(num_cells)
        chosen = np.random.choice(
            size=(num_samples,), a=flat_indices, p=lambdas.flatten()
        )
        # Turn the flat positions back into one coordinate per axis.
        coords = np.vstack(np.unravel_index(chosen, lambdas.shape)).T
        return torch.from_numpy(coords).float()

    @staticmethod
    def sample_from_cauchy(num_samples, dimensions, std=1):
        '''
        Draw frequencies from a Cauchy distribution with scale ``std``.

        :return: tensor of shape (num_samples, len(dimensions))
        '''
        return torch.empty(num_samples, len(dimensions)).cauchy_(sigma=std)

    @staticmethod
    def sample_from_laplace(num_samples, dimensions, std=1):
        '''
        Draw frequencies from a zero-centred Laplace distribution with scale
        ``std``.

        :return: tensor of shape (num_samples, len(dimensions))
        '''
        # Scalar (0-dim) parameters keep the batch shape empty, so the sample
        # has exactly the requested shape; 1-element tensors would append a
        # trailing axis of size 1.
        laplace = torch.distributions.laplace.Laplace(
            torch.tensor(0.0), torch.tensor(float(std))
        )
        return laplace.rsample(sample_shape=(num_samples, len(dimensions)))

    @staticmethod
    def sample_from_uniform(num_samples, dimensions, rescale_dims=True):
        '''
        Draw frequencies uniformly from [0, 1), optionally multiplying column
        ``i`` by ``dimensions[i]``.

        :return: tensor of shape (num_samples, len(dimensions))
        '''
        k = torch.rand(num_samples, len(dimensions))
        if rescale_dims:
            for i, dim_size in enumerate(dimensions):
                k[..., i] *= dim_size
        return k

    @staticmethod
    def sample_from_normal(num_samples, dimensions, sigma=1, rescale_dims=True):
        '''
        Draw frequencies from a zero-mean normal distribution with standard
        deviation ``sigma``, optionally multiplying column ``i`` by
        ``dimensions[i]``.

        :return: tensor of shape (num_samples, len(dimensions))
        '''
        k = torch.randn(num_samples, len(dimensions)) * sigma
        if rescale_dims:
            for i, dim_size in enumerate(dimensions):
                k[..., i] *= dim_size
        return k

    def forward(self, x):
        '''
        Embed ``x``.

        :param x: tensor whose last axis has one entry per column of ``k``
        :return: tensor with the leading axes of ``x`` and ``num_features``
        entries on the last axis
        '''
        x_rff = math.sqrt(2) * torch.cos(x @ self.k.T + self.b[None, :])
        # The phase broadcast adds a leading axis when ``x`` is 1-D; restoring
        # the leading axes of ``x`` removes it and is a no-op otherwise.
        return x_rff.view(*x.shape[:-1], -1)


class HarmonicEmbedding(torch.nn.Module):
    def __init__(self, num_features=60, omega0=0.1, dimensions=None):
        """
        Given an input tensor `x` of shape [minibatch, ... , dim],
        the harmonic embedding layer converts each coordinate
        of `x` into a series of harmonic features. With
        `f_j = omega0 * 2**j` for `j = 0 .. num_features - 1` and
        `phase[..., i*num_features + j] = f_j * x[..., i]`, the result is

            embedding = [sin(phase), cos(phase)]

        concatenated along the last axis, i.e. all sines come first and
        all cosines second.

        Args:
            num_features: number of frequencies per input coordinate
            omega0: base frequency that multiplies every power of two
            dimensions: describes the input coordinates and is only used
                to size `output_features`. `None` keeps the historical
                assumption of 3 input coordinates; a sequence contributes
                its length. NOTE(review): an integer is taken to be the
                number of coordinates - confirm against callers.
        """
        super().__init__()
        self.register_buffer(
            'frequencies',
            omega0 * (2.0 ** torch.arange(num_features)),
        )

        if dimensions is None:
            input_dim = 3
        elif isinstance(dimensions, int):
            input_dim = dimensions
        else:
            input_dim = len(dimensions)
        # One sine and one cosine per (coordinate, frequency) pair.
        self.output_features = num_features * input_dim * 2

    def forward(self, x):
        """
        Args:
            x: tensor of shape [..., dim]
        Returns:
            embedding: a harmonic embedding of `x`
                of shape [..., num_features * dim * 2]
        """
        # [..., dim, 1] * [num_features] -> [..., dim, num_features], then the
        # last two axes are flattened into one.
        phase = (x[..., None] * self.frequencies).view(*x.shape[:-1], -1)
        return torch.cat((phase.sin(), phase.cos()), dim=-1)


class FourierBasisEmbedding(torch.nn.Module):
    def __init__(self, num_features=60, dimensions=(1, 1, 1), coef=1.0, symmetic_freq=True):
        """
        Given an input tensor `x` of shape [minibatch, ... , dim], the
        Fourier basis embedding evaluates sines and cosines of `x` on a
        regular grid of frequency vectors:

            embedding = [sin(x @ k.T), cos(x @ k.T)]

        The rows of `k` are the Cartesian product, over the input
        coordinates, of the per-coordinate frequencies
        `coef * pi * m / dimensions[i]`.

        Args:
            num_features: controls how many integers `m` are used per
                coordinate (see `symmetic_freq`)
            dimensions: one divisor per input coordinate; its length is
                the number of input coordinates
            coef: multiplier applied to every frequency
            symmetic_freq: if true, `m` runs over
                `-n, ..., 0, ..., n` with `n = num_features // 2`
                (that is `2 * n + 1` values, one more than `num_features`
                when `num_features` is even); otherwise `m` runs over
                `1, ..., num_features`
        """
        super().__init__()

        if symmetic_freq:
            n = num_features // 2
            freq = torch.linspace(-n, n, 2 * n + 1)
        else:
            freq = torch.arange(1, num_features + 1)

        k = torch.cartesian_prod(*[coef * math.pi * freq / dim for dim in dimensions])
        if k.dim() == 1:
            # With a single input coordinate cartesian_prod hands back the
            # 1-D frequency vector; keep `k` as [num_frequencies, 1] so that
            # `x @ k.T` works for every input dimensionality.
            k = k[:, None]
        self.register_buffer('k', k)

        # One sine and one cosine per frequency vector. Taking the size from
        # `k` keeps it right for any number of coordinates and for the
        # symmetric grid, which may hold num_features + 1 points per axis.
        self.output_features = 2 * k.shape[0]

    def forward(self, x):
        """
        Args:
            x: tensor of shape [..., dim]
        Returns:
            embedding: a Fourier basis embedding of `x`
                of shape [..., output_features]
        """
        phase = (x @ self.k.T).view(*x.shape[:-1], -1)
        return torch.cat((phase.sin(), phase.cos()), dim=-1)